#-*- coding: utf-8 -*- 

'''
Created on Oct 12, 2010
Decision Tree Source Code for Machine Learning in Action Ch. 3
@author: Peter Harrington
'''
from math import log      
import operator

# #训练数据集
# def createDataSet():
# 	#类别值有两个：yes,no
#     dataSet = [[1, 1, 'yes'],
#                [1, 1, 'yes'],
#                [1, 0, 'no'],
#                [0, 1, 'no'],
#                [0, 1, 'no']]
# 	#两个特征名称
#     labels = ['no surfacing','flippers']
#     #change to discrete values
#     return dataSet, labels

# Feature encoding for the loan-approval training data:
#   age:         youth = 1, middle-aged = 2, senior = 3
#   work:        has a job = 1, no job = 0
#   house:       owns a house = 1, does not = 0
#   credibility: fair = 1, good = 2, excellent = 3
#   label:       loan approved = 'yes', rejected = 'no'
# Training data set
def createDataSet():
    """Return the loan-approval training data and its feature names.

    Returns:
        (dataSet, labels) where each dataSet row is
        [age, work, house, credibility, classLabel] and labels names
        the four feature columns in order.
    """
    samples = [
        [1, 0, 0, 1, 'no'],
        [1, 0, 0, 2, 'no'],
        [1, 1, 0, 2, 'yes'],
        [1, 1, 1, 1, 'yes'],
        [1, 0, 0, 1, 'no'],
        [2, 0, 0, 1, 'no'],
        [2, 0, 0, 2, 'no'],
        [2, 1, 1, 2, 'yes'],
        [2, 0, 1, 3, 'yes'],
        [2, 0, 1, 3, 'yes'],
        [3, 0, 1, 3, 'yes'],
        [3, 0, 1, 2, 'yes'],
        [3, 1, 0, 2, 'yes'],
        [3, 1, 0, 3, 'yes'],
        [3, 0, 0, 1, 'no'],
    ]
    # Names of the four feature columns (the fifth column is the class label).
    feature_names = ['age', 'work', 'house', 'credibility']
    return samples, feature_names

# Compute the Shannon entropy of a data set's class labels.
def calcShannonEnt(dataSet):
    """Return the Shannon entropy (base-2 log) of the class labels.

    dataSet: list of rows; the LAST element of each row is the class label.
    Returns 0.0 for an empty data set (the original raised ZeroDivisionError).
    """
    numEntries = len(dataSet)
    if numEntries == 0:
        # No samples -> no uncertainty; also avoids division by zero below.
        return 0.0
    # Count how many samples fall in each class: {label: count}.
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        # dict.get with a default replaces the keys()-membership dance.
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    # H = -sum(p_i * log2(p_i)) over all classes.
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
	
# Partition a data set on one feature value.
def splitDataSet(dataSet, axis, value):
    """Return the rows of dataSet whose feature at index `axis` equals
    `value`, with that feature column removed from each returned row.

    dataSet: list of rows (feature values plus trailing class label).
    axis:    index of the feature column to split on.
    value:   the feature value selecting which rows to keep.
    """
    # Concatenating the slices before and after `axis` drops that column;
    # each returned row is a fresh list, so dataSet is never mutated.
    return [row[:axis] + row[axis + 1:]
            for row in dataSet
            if row[axis] == value]
 
# Pick the split feature with the largest information gain.
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature whose split maximizes information gain.

    dataSet rows are feature values followed by the class label in the
    last column. Returns -1 when no feature yields a positive gain.
    """
    featureCount = len(dataSet[0]) - 1
    # Entropy of the whole data set, before any split.
    baseEntropy = calcShannonEnt(dataSet)
    bestGain, bestFeature = 0.0, -1
    for featIdx in range(featureCount):
        # Distinct values taken by this feature across all rows.
        distinctValues = {row[featIdx] for row in dataSet}
        # Conditional entropy: weighted entropy of each value's subset.
        condEntropy = 0.0
        for val in distinctValues:
            subset = splitDataSet(dataSet, featIdx, val)
            weight = len(subset) / float(len(dataSet))
            condEntropy += weight * calcShannonEnt(subset)
        # Information gain = overall entropy - conditional entropy.
        gain = baseEntropy - condEntropy
        if gain > bestGain:
            bestGain, bestFeature = gain, featIdx
    return bestFeature

# When all features are exhausted, fall back to the most frequent class.
def majorityCnt(classList):
    """Return the class label that occurs most often in classList.

    classList: list of class labels (e.g. 'yes'/'no').
    """
    # Tally samples per class: {label: count}.
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # .items() works on both Python 2 and 3 (the original .iteritems()
    # is Python-2-only); sort by count, descending.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    # First entry is (label, count) of the majority class.
    return sortedClassCount[0][0]

# Recursively build an ID3 decision tree.
def createTree(dataSet, labels):
    """Build an ID3 decision tree recursively.

    dataSet: rows of feature values plus the class label in the last column.
    labels:  feature names aligned with the feature columns; NOT mutated
             (the original deleted the chosen name from the caller's list).
    Returns a nested dict {featureName: {featureValue: subtree-or-label}},
    or a bare class label string for a leaf.
    """
    classList = [example[-1] for example in dataSet]
    # Stop: every sample has the same class -> leaf with that class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop: only the label column remains -> majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # Split on the feature with the largest information gain.
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Feature names for the subtrees, excluding the one just used.
    # Built as a new list so the caller's `labels` is left untouched.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    # One branch per distinct value of the chosen feature.
    featValues = set(example[bestFeat] for example in dataSet)
    for value in featValues:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels[:])
    return myTree
	
# Classify one sample by walking the trained decision tree.
def classify(inputTree, featLabels, testVec):
    """Return the predicted class label for testVec.

    inputTree:  nested dict produced by createTree
                ({featureName: {featureValue: subtree-or-label}}).
    featLabels: feature names aligned with testVec's columns.
    testVec:    feature values of the sample to classify.

    Raises KeyError if testVec holds a feature value unseen during training.
    """
    # Root node name; next(iter(...)) works on Python 2 and 3, unlike
    # the original inputTree.keys()[0]. (Debug prints removed.)
    firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]
    # Which column of testVec corresponds to the root's feature.
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):
        # Internal node: recurse into the matching branch.
        return classify(valueOfFeat, featLabels, testVec)
    # Leaf node: the stored value is the class label.
    return valueOfFeat

def storeTree(inputTree, filename):
    """Serialize inputTree to `filename` with pickle.

    Opens the file in BINARY mode ('wb') — pickle requires it on Python 3
    and on Windows under Python 2; the original text-mode 'w' corrupts the
    stream there. The `with` block guarantees the handle is closed.
    """
    import pickle
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
    
def grabTree(filename):
    """Load and return a pickled decision tree from `filename`.

    Opens in BINARY mode ('rb') to match storeTree's binary pickle, and
    uses `with` so the handle is closed (the original leaked it).
    """
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
	

    
#==================================================================
####绘图：绘构建的树
import matplotlib.pyplot as plt

# Box styles passed as `bbox` to matplotlib annotate(): sawtooth boxes for
# internal (decision) nodes, rounded boxes for leaves; fc="0.8" = light gray.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
# Arrow style for annotate(): drawn pointing from child back to parent ("<-").
arrow_args = dict(arrowstyle="<-")

def getNumLeafs(myTree):
    """Return the number of leaf nodes in a createTree-style nested dict.

    A child that is a dict is a subtree; anything else is a leaf.
    """
    numLeafs = 0
    # next(iter(...)) replaces Python-2-only myTree.keys()[0].
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            # Subtree: count its leaves recursively.
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    """Return the depth (number of decision levels) of a createTree-style
    nested dict; a tree whose children are all leaves has depth 1."""
    maxDepth = 0
    # next(iter(...)) replaces Python-2-only myTree.keys()[0].
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            # Subtree: one level plus its own depth.
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one tree node as an annotated box at centerPt, with an arrow
    from parentPt; nodeType selects the decision-node or leaf box style."""
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt, xycoords='axes fraction',
        xytext=centerPt, textcoords='axes fraction',
        va="center", ha="center",
        bbox=nodeType, arrowprops=arrow_args)
    
def plotMidText(cntrPt, parentPt, txtString):
    """Write txtString at the midpoint of the edge from parentPt to cntrPt
    (used to label a branch with its feature value)."""
    mid_x = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    mid_y = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(mid_x, mid_y, txtString,
                        va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
    """Recursively draw the subtree myTree below parentPt, labeling the
    incoming edge with nodeTxt.

    Layout state lives in function attributes set by createPlot:
    plotTree.totalW/totalD (total leaves / depth) and the running cursor
    plotTree.xOff/yOff, both mutated as the recursion walks the tree.
    NOTE(review): myTree.keys()[0] is Python-2-only syntax.
    """
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0]     #the text label for this node should be this
    # Center this node horizontally over the leaves it spans.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Descend one level before drawing children.
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes   
            plotTree(secondDict[key],cntrPt,str(key))        #recursion
        else:   #it's a leaf node print the leaf node
            # Advance the x cursor one leaf slot, then draw the leaf.
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore the y cursor when unwinding back to the parent level.
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#if you do get a dictonary you know it's a tree, and the first element will be another dict

def createPlot(inTree):
    """Render the decision tree inTree in a matplotlib figure and show it."""
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axis_props = dict(xticks=[], yticks=[])
    # Axes without ticks; stored on the function for the plot helpers.
    createPlot.ax1 = plt.subplot(111, frameon=False, **axis_props)
    # Seed the layout state plotTree reads and mutates.
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

##def createPlot():
##    fig = plt.figure(1, facecolor='white')
##    fig.clf()
##    createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
##    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
##    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
##    plt.show()

##def retrieveTree(i):
##    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
##                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
##                  ]
##    return listOfTrees[i]
##
###createPlot(thisTree)
####绘图：绘构建的树
#==================================================================


    
    
if __name__ == "__main__":

    #****简单数据测试
    mydat, labels = createDataSet()
	#赋值后，若featlabels改变，labels不会随之改变
    featlabels = labels[:]
    print 'featlabels=', featlabels

    #创建决策树
    mytree = createTree(mydat, featlabels)
    print mytree
    createPlot(mytree) #绘图：构建的树

    #预测测试样本数据[1,0,1,2]的类别
    # prelabel = classify(mytree, labels, [1, 0, 0, 1])
    # print 'prelabel=', prelabel
    

    # #****隐形眼镜数据集
    fr=open('lenses.txt','r')
    lenses=[line.strip().split('\t')for line in fr.readlines()]
    lensesLabels=['age', 'prescript', 'astigmatic', 'tearRate']
    lensesTree=createTree(lenses, lensesLabels)
    print lensesTree
    createPlot(lensesTree) #绘图：构建的树
    

